#   Copyright 2023-2025, Jianbo Zhu, Jingyu Li, Peng-Fei Liu
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.


import numpy as np
from io import StringIO
from collections import OrderedDict
from .misc import __prog__


class Cell():
    '''
    A pythonic POSCAR class

    Attribute:
        basis: basis vectors (one vector per row), (3,3) np.ndarray
        sites: atomic positions in fractional coordinates, a OrderedDict:
            'A': [[f_a_1, f_b_1, f_c_1],
                  [f_a_2, f_b_2, f_c_2],
                  ...,
                 ]
            'B': [[f_a_1, f_b_1, f_c_1],
                  [f_a_2, f_b_2, f_c_2],
                  ...,
                 ]
            ...
    '''
    def __init__(self, basis=None, sites=None):
        '''
        Parameters
        ----------
        basis : array-like, optional
            Basis vectors as rows. The 3x3 identity matrix is used if None.
        sites : OrderedDict, dict or iterable of (atom, positions), optional
            Fractional positions grouped by species. An OrderedDict is stored
            as-is (not copied); anything else is converted to an OrderedDict.
        '''
        if basis is None:
            self.basis = np.identity(3)
        else:
            self.basis = np.asarray(basis)

        if sites is None:
            self.sites = OrderedDict()
        elif isinstance(sites, OrderedDict):
            self.sites = sites
        else:
            self.sites = OrderedDict(sites)

    @classmethod
    def from_poscar(cls, poscar='POSCAR'):
        '''
        Create a Cell object from POSCAR file.

        The file must contain an element-symbol line followed by a count
        line. The scaling factor is applied to the basis; a negative factor
        is interpreted as the target cell volume. Cartesian positions are
        scaled by the same factor and converted to fractional coordinates.

        Parameters
        ----------
        poscar : str, optional
            Filename of POSCAR, by default 'POSCAR'

        Returns
        -------
        Cell

        Raises
        ------
        NotImplementedError
            If the POSCAR uses selective dynamics.
        '''
        with open(poscar, 'r') as f:
            f.readline()        # comment header

            # parse basis
            scale = float(f.readline().strip())
            basis = np.array([[float(i) for i in f.readline().split()[:3]]
                              for _ in range(3)])
            if scale < 0:
                # volume mode: -scale is the target volume. The factor is
                # refined to a length scale because Cartesian positions are
                # scaled by it below; abs(det) keeps left-handed bases valid.
                scale = np.cbrt(-scale/abs(np.linalg.det(basis)))
            basis = scale * basis

            # parse species, counts and coordinate type
            elts = f.readline().strip().split()
            nums = [int(i) for i in f.readline().split()]
            postype = f.readline().strip()[0]       # first letter
            # Must be checked before reading positions: in selective dynamics
            # mode the next line is the coordinate-type line, not a position.
            if postype in 'sS':
                err_info = 'POSCAR with selective dynamics mode is not supported'
                raise NotImplementedError(err_info)

            # parse sites
            sites = OrderedDict()
            for num, elt in zip(nums, elts):
                for _ in range(num):
                    pos = [float(i) for i in f.readline().split()[:3]]
                    sites.setdefault(elt, []).append(pos)

        if postype in 'cCkK':
            # Cartesian mode: positions carry the same universal scaling
            # factor as the basis; convert to fractional coordinates.
            trans = np.linalg.inv(basis)
            for elt, site in sites.items():
                sites[elt] = [scale * np.array(pos) @ trans for pos in site]
        else:
            # direct, fractional coordinates
            for elt, site in sites.items():
                sites[elt] = [np.array(pos) for pos in site]

        # create Cell object
        return cls(basis, sites)

    def write(self, poscar=f'POSCAR.{__prog__}', header=None, fmt='%22.15f', eps_zero=1e-14):
        '''
        Write into POSCAR (only support fractional coordinates).

        The scaling factor is always written as 1 and every position line is
        followed by a label '<atom><idx>'.

        Parameters
        ----------
        poscar : str, optional
            Output filename, by default 'POSCAR.<prog>'
        header : str, optional
            Comment line. 'POSCAR generated by <prog>' is used if None or empty.
        fmt : str, optional
            %-style format applied to every basis/position component
        eps_zero : float or None, optional
            Components with absolute value below this are written as 0;
            None disables this clean-up.
        '''
        if eps_zero is None:
            fseq = lambda _seq: ' '.join(fmt % x for x in _seq)
        else:
            fseq = lambda _seq: ' '.join(fmt % (0 if abs(x)<eps_zero else x) for x in _seq)

        basis = self.basis
        sites = self.sites
        lines = [f"{header or f'POSCAR generated by {__prog__}'}\n", ]
        lines.append(f'{1:8.3f}\n')
        for basis_i in basis:
            lines.append(fseq(basis_i)+'\n')
        elts = [f' {k:>4s}' for k in sites.keys()]
        nums = [f' {len(v):>4d}' for v in sites.values()]
        lines.append(''.join(elts)+'\n')
        lines.append(''.join(nums)+'\n')
        lines.append('Direct\n')
        for elt, idx, pos in self.all_pos():
            lines.append(f'{fseq(pos)}    {elt}{idx}\n')
        with open(poscar, 'w') as f:
            f.writelines(lines)

    def pop(self, atom, idx=1):
        '''
        Remove atom_idx

        The species entry is deleted once its last site is removed.

        Parameters
        ----------
        atom : str
            Type of atom
        idx : int, optional
            The index of pop atom (index start from 1), by default 1

        Returns
        -------
        ndarray in shape (3,)
            The postion of pop atom

        Raises
        ------
        RuntimeError
            If `atom` is not present in the cell.
        '''
        if atom in self.sites:
            pos = self.sites[atom].pop(idx-1)
            if len(self.sites[atom]) == 0:
                del self.sites[atom]
        else:
            raise RuntimeError(f'Failed to locate {atom}')
        return pos

    def index(self, atom, idx=1):
        '''
        Get global index (1-start) from atom & idx (1-start)

        Returns -1 if the (atom, idx) pair does not exist.
        '''
        for index_, (atom_, idx_, *_) in enumerate(self.all_pos(), start=1):
            if (atom == atom_) and (idx == idx_):
                return index_
        else:
            return -1

    def insert(self, atom, pos, idx=1, tohead=True):
        '''
        Insert a new atom at specified postion

        Parameters
        ----------
        atom : str
            Type of atom
        pos : ndarray with shape (3,)
            The position of new atom will be inserted
        idx : int, optional
            The index of order to be inserted (start from 1), by default 1.
            Ignored when `atom` is a new species.
        tohead : bool, optional
            Whether to put the species first if it is new, by default True
        '''
        if atom in self.sites:
            self.sites[atom].insert(idx-1, pos)
        else:
            self.sites[atom] = [pos,]
            if tohead:
                self.sites.move_to_end(atom, last=False)

    def get_volume(self):
        '''
        Calculate volume of cell (signed determinant of the basis)
        '''
        basis = np.array(self.basis)
        volume = np.linalg.det(basis)
        return volume

    def get_natom(self):
        '''
        Get the total number of atoms in the cell
        '''
        return sum(len(site) for site in self.sites.values())

    def get_pos(self, numbers):
        '''
        Get (atom, idx, pos) combination from given numbers (start from 1)

        Parameters
        ----------
        numbers : List[int]
            List of global idx (start from 1); values are rounded to int

        Returns
        -------
        list
            [(atom, idx, pos), ...]
        '''
        all_pos = list(self.all_pos())
        return [all_pos[round(i)-1] for i in numbers]

    def get_dist(self, pos):
        '''
        Calculate distance between given positions and all sites in cell

        Each distance is the minimum over the 27 images obtained by shifting
        the fractional difference by -1, 0 or +1 along each basis vector.

        Parameters
        ----------
        pos : List[pos]
            List of positions (fractional) to calculate distances

        Returns
        -------
        ndarray
            A ndarray of distances in shape (N_pos, N_cell_sites)

        Raises
        ------
        ValueError
            If `pos` is empty or its last dimension is not 3.
        '''

        # get all site positons
        poss = np.vstack(list(self.sites.values()))
        pp1 = np.array(poss).reshape((-1, 1, 3))    # shape: (N1, 1, 3)

        pp2 = np.atleast_2d(pos)                    # shape: (N2, 3)
        # the emptiness check must come first: atleast_2d([]) has shape (1, 0)
        if pp2.size == 0:
            raise ValueError('An empty `pos` is given!')
        elif pp2.shape[-1] != 3:
            raise ValueError('The shape of `pos` must is (..., 3)')

        c1, c2, c3 = np.mgrid[-1:2,-1:2,-1:2]
        cc = np.c_[c1.flatten(),c2.flatten(),c3.flatten()]
        cc = np.reshape(cc, (-1, 1, 1, 3))          # shape: (27, 1, 1, 3)

        # find the nearest site
        dr = (pp2 - pp1 + cc) @ np.array(self.basis)
        dists = np.linalg.norm(dr, ord=2, axis=-1)  # shape: (27, N1, N2)
        dmin = np.min(dists, axis=0)                # shape: (N1, N2)
        return np.transpose(dmin)                   # shape: (N2, N1)

    def loc_pos(self, pos):
        '''
        Locate the nearest site in cell to corresponding given position

        Parameters
        ----------
        pos : List[pos]
            List of positions to be located at.

        Returns
        -------
        list
            [(atom, idx, pos), ...]
        '''
        dmin = self.get_dist(pos)
        ix = np.argmin(dmin, axis=-1)
        return self.get_pos(ix+1)

    def all_pos(self, atoms=None):
        '''
        Yield all (atom, idx, pos) combination

        Parameters
        ----------
        atoms : tuple or None, optional
            Specifies which atoms go into the iterator. If None (by default), all
            atoms will go into iterator.

        Yields
        ------
        tuple
            (atom, idx, pos), idx starting from 1 within each species
        '''
        if atoms:
            for atom in atoms:
                for idx, pos in enumerate(self.sites[atom], start=1):
                    yield (atom, idx, pos)
        else:
            for atom, site in self.sites.items():
                for idx, pos in enumerate(site, start=1):
                    yield (atom, idx, pos)


def read_energy(outcar='OUTCAR', average=False):
    '''
    Read final energy from OUTCAR.

    The energy is the last number on the last line that contains 'sigma'.

    Parameters
    ----------
    outcar : str, optional
        Filename of OUTCAR, by default 'OUTCAR'
    average : bool, optional
        If True, divide the energy by the number of ions taken from the
        first line containing 'NIONS', by default False

    Returns
    -------
    float
        Final energy (per atom if `average` is True)

    Raises
    ------
    ValueError
        If no 'sigma' line (or, with `average`, no 'NIONS' line) is found.
    '''
    with open(outcar, 'r') as f:
        data = f.readlines()

    for line in reversed(data):
        if 'sigma' in line:
            energy = float(line.rstrip().split()[-1])
            break
    else:
        raise ValueError(f'Failed to read final energy from {outcar}')

    if average:
        for line in data:
            if 'NIONS' in line:
                natom = int(line.strip().split()[-1])
                break
        else:
            raise ValueError(f'Failed to read NIONS from {outcar}')
        energy /= natom
    return energy


def read_ewald(outcar='OUTCAR'):
    '''
    Read final Ewald from OUTCAR.

    The value after the last '=' on the last line containing
    'Ewald energy   TEWEN' is returned as an absolute value.

    Parameters
    ----------
    outcar : str, optional
        Filename of OUTCAR, by default 'OUTCAR'

    Returns
    -------
    float
        Absolute value of the last Ewald energy entry

    Raises
    ------
    ValueError
        If no Ewald energy line is found.
    '''
    with open(outcar, 'r') as f:
        data = f.readlines()
    for line in reversed(data):
        if 'Ewald energy   TEWEN' in line:
            return abs(float(line.rstrip().split('=')[-1]))
    raise ValueError(f'Failed to read Ewald energy from {outcar}')


def read_pot(outcar='OUTCAR'):
    '''
    Read final site electrostatic potentials from OUTCAR.

    Uses the last line containing 'electrostatic'. Data rows start three
    lines below it and end at the first blank line. Each row is a sequence
    of 17-character records whose value sits in characters 8-16.

    Parameters
    ----------
    outcar : str, optional
        Filename of OUTCAR, by default 'OUTCAR'

    Returns
    -------
    list of float
        Potentials in file order

    Raises
    ------
    ValueError
        If no 'electrostatic' line is found.
    '''
    with open(outcar, 'r') as f:
        data = f.readlines()

    # position of the last 'electrostatic' header line
    for start in range(len(data)-1, -1, -1):
        if 'electrostatic' in data[start]:
            break
    else:
        raise ValueError(f'Failed to locate electrostatic potentials in {outcar}')

    pot = []
    # skip the header and the two descriptive lines that follow it
    for line in data[start+3:]:
        line = line.rstrip()
        if not line:
            break
        while line:
            pot.append(float(line[8:17]))
            line = line[17:]
    return pot


def read_volume(outcar='OUTCAR'):
    '''
    Read volume in A^3 from OUTCAR file

    The last number on the last line containing 'volume' is returned.

    Parameters
    ----------
    outcar : str, optional
        Filename of OUTCAR, by default 'OUTCAR'

    Returns
    -------
    float

    Raises
    ------
    ValueError
        If no line contains 'volume'.
    '''
    with open(outcar, 'r') as f:
        data = f.readlines()
    for line in reversed(data):
        if 'volume' in line:
            return float(line.strip().split()[-1])
    raise ValueError(f'Failed to read volume from {outcar}')


def read_epsilon(outcar='OUTCAR', isNumeric=False):
    '''
    Collect every block whose header line contains 'STATIC DIELECTRIC'.

    Parameters
    ----------
    outcar : str, optional
        Filename of OUTCAR, by default 'OUTCAR'
    isNumeric : bool, optional
        If True, skip the line right after the header and parse the next
        three lines as rows of floats. Otherwise keep the next six lines as
        stripped strings. By default False.

    Returns
    -------
    list of (str, list)
        One (stripped header, values) pair per matching block, in file
        order. Lines requested past end-of-file come back as '' (text
        mode) or [] (numeric mode).
    '''
    marker = 'STATIC DIELECTRIC'
    blocks = []
    with open(outcar, 'r') as handle:
        stream = iter(handle)

        def next_line():
            # behaves like readline(): empty string once the file is exhausted
            return next(stream, '')

        for header in stream:
            if marker not in header:
                continue
            if isNumeric:
                next_line()     # line directly below the header is skipped
                body = [[float(tok) for tok in next_line().strip().split()]
                        for _ in range(3)]
            else:
                body = [next_line().strip() for _ in range(6)]
            blocks.append((header.strip(), body))
    return blocks
    

def read_eigval(eigenval='EIGENVAL'):
    '''
    Read EIGENVAL file

    Layout assumed by the slicing below: line 6 holds three integers. From
    line 8 on, each k-point block is one k-point line, `eig_num` band lines
    (index, energy, occupation) and one separator line. Only the 2nd and
    3rd band columns are read.

    Parameters
    ----------
    eigenval : str, optional
        Filename of EIGENVAL. The default is 'EIGENVAL'.

    Returns
    -------
    (ele_num, kpt_num, eig_num), (kpts, kptw), (energy, weight)
    *_num: scalar
    kpts: (Nkpt,3)
    kptw: (Nkpt,)
    energy & weight: (Nbd, Nkpt)

    '''
    with open(eigenval, 'r') as f:
        data = f.readlines()
    ele_num, kpt_num, eig_num = map(int, data[5].rstrip().split())
    stride = eig_num + 2    # k-point line + eig_num band lines + separator

    # ndmin=2 keeps a (Nkpt, ncol) layout even for a single k-point
    kptdata = np.loadtxt(StringIO(''.join(data[7::stride])), ndmin=2)
    kpts = kptdata[:, :3]   # shape of (Nkpt,3)
    kptw = kptdata[:, 3]    # shape of (Nkpt,)
    energy = []   # shape of (Nbd, Nkpt)
    weight = []   # shape of (Nbd, Nkpt)
    for i in range(eig_num):
        ei, wi = np.loadtxt(StringIO(''.join(data[8+i::stride])),
                            usecols=(1, 2), unpack=True, ndmin=2)
        energy.append(ei)
        weight.append(wi)
    energy = np.vstack(energy)
    weight = np.vstack(weight)
    return (ele_num, kpt_num, eig_num), (kpts, kptw), (energy, weight)


def read_evbm(eigenval='EIGENVAL', pvalue=0.1):
    '''
    Read VBM & CBM energy and corresponding k-points. Threshold value to 
    determine unoccupied bands is allowed to assigned manually(0.1 default).

    Bands are scanned from the bottom. The CBM band is the first band whose
    occupation is below `pvalue` at every k-point while the previous band's
    minimum occupation exceeds 1-`pvalue`. If no such band exists, the top
    two bands are used.

    Parameters
    ----------
    eigenval : str, optional
        Filename of EIGENVAL. The default is 'EIGENVAL'.
    pvalue : float, optional
        Threshold value. The default is 0.1.

    Returns
    -------
    (e_vbm, index, k_vbm), (e_cbm, index, k_cbm), Egap
        NOTE(review): the band indices here are 0-based, unlike
        read_evbm_from_ne which returns 1-based indices.

    '''
    with open(eigenval, 'r') as f:
        data = f.readlines()
    *_, eig_num = map(int, data[5].rstrip().split())
    stride = eig_num + 2    # k-point line + eig_num band lines + separator

    # ndmin=2 keeps a (Nkpt, ncol) layout even for a single k-point
    kpts = np.loadtxt(StringIO(''.join(data[7::stride])), ndmin=2)
    kpts[np.abs(kpts) < 1E-8] = 0
    energy = []
    wx = 1      # minimum occupation of the previously scanned band
    for i in range(eig_num):
        ei, wi = np.loadtxt(StringIO(''.join(data[8+i::stride])),
                            usecols=(1, 2), unpack=True, ndmin=2)
        energy.append(ei)
        if wi.max() < pvalue and wx > (1-pvalue):
            break
        else:
            wx = wi.min()
    idxc = np.argmin(energy[-1])
    idxv = np.argmax(energy[-2])
    e_cbm = energy[-1][idxc]
    k_cbm = kpts[idxc]
    e_vbm = energy[-2][idxv]
    k_vbm = kpts[idxv]
    return (e_vbm, i-1, k_vbm[:3]), (e_cbm, i, k_cbm[:3]), e_cbm-e_vbm


def read_evbm_from_ne(eigenval='EIGENVAL', Ne=None, dNe=0):
    '''
    Read VBM & CBM energy from the number of electrons.

    The VBM band is band int((Ne + dNe)/2) (1-based) and the CBM band is
    the next one. The VBM/CBM are the max/min of those bands over k-points.

    Parameters
    ----------
    eigenval : str, optional
        Filename of EIGENVAL. The default is 'EIGENVAL'.
    Ne : int or float, optional
        The number of electron. If None(default), read from EIGENVAL file.
    dNe : int, optional
        Additional adjustments of Ne.

    Returns
    -------
    (e_vbm, index, k_vbm), (e_cbm, index, k_cbm), Egap
        Band indices are 1-based.

    '''
    with open(eigenval, 'r') as f:
        data = f.readlines()
    e_num, _, eig_num = map(int, data[5].rstrip().split())
    stride = eig_num + 2    # k-point line + eig_num band lines + separator

    if Ne is None:
        idxv = int((e_num + dNe)/2)  # start from 1
    else:
        idxv = int((Ne + dNe)/2)   # start from 1
    idxv, idxc = idxv-1, idxv      # start from 0

    # rows of (band index, energy, occupation), one row per k-point;
    # ndmin=2 keeps that layout even for a single k-point
    e_vbms = np.loadtxt(StringIO(''.join(data[8+idxv::stride])), ndmin=2)
    e_cbms = np.loadtxt(StringIO(''.join(data[8+idxc::stride])), ndmin=2)
    idxvk = np.argmax(e_vbms[:, 1])
    idxck = np.argmin(e_cbms[:, 1])

    k_vbm = np.loadtxt(StringIO(data[7+idxvk*stride]))
    k_cbm = np.loadtxt(StringIO(data[7+idxck*stride]))

    e_vbm = e_vbms[idxvk, 1]
    e_cbm = e_cbms[idxck, 1]
    return (e_vbm, idxv+1, k_vbm[:3]), (e_cbm, idxc+1, k_cbm[:3]), e_cbm-e_vbm
    

def read_dos(doscar='DOSCAR', efermi=0):
    '''
    Read DOS data from DOSCAR or tdos.dat

    A file is treated as DOSCAR when its name is exactly 'DOSCAR' or when
    its header matches the DOSCAR layout: 4 tokens on line 1 and a single
    token on each of lines 3 and 4. For a DOSCAR only the NEDOS lines after
    the 6-line header are read, where NEDOS is the 3rd token of line 6.
    Lines starting with '#' and blank lines are skipped.

    Parameters
    ----------
    doscar : str, optional
        Filename, by default 'DOSCAR'
    efermi : float, optional
        Reference subtracted from the first column, by default 0

    Returns
    -------
    energy, dos : list of float
        First column (shifted by `efermi`) and second column of each line
    '''
    with open(doscar, 'r') as f:
        data = f.readlines()

    # detect DOSCAR (guarded so that very short files do not raise)
    num_i = [len(line.split()) for line in data[:6]]
    looks_like_doscar = (len(num_i) >= 4
                         and num_i[0] == 4 and num_i[2] == 1 and num_i[3] == 1)
    if doscar == 'DOSCAR' or looks_like_doscar:
        NEDOS = int(data[5].split()[2])
        data = data[6:6+NEDOS]

    energy = []
    dos = []
    for line in data:
        if line.startswith('#'):
            continue
        row = [float(item) for item in line.split()]
        if not row:
            continue        # blank line
        energy.append(row[0]-efermi)
        dos.append(row[1])

    return energy, dos


def read_zval(potcar='POTCAR'):
    '''
    Map each species in POTCAR to its ZVAL, in file order.

    A line containing 'TITEL' sets the current species: its 4th whitespace
    token, cut at the first '_' (e.g. 'Si_sv' -> 'Si'). The next line
    containing 'ZVAL' stores the 2nd token after 'ZVAL' as a float for that
    species.

    Returns
    -------
    z_dict: dict of {atom: zval}
    '''
    zvals = OrderedDict()
    with open(potcar, 'r') as handle:
        for row in handle:
            if 'TITEL' in row:
                species = row.split()[3].split('_')[0]
                continue
            if 'ZVAL' not in row:
                continue
            after_tag = row.split('ZVAL')[1]
            zvals[species] = float(after_tag.split()[1])
    return zvals

def read_transmat(transform='TRANSMAT.in'):
    '''
    Build a 3x3 transformation matrix from TRANSMAT.in, a string or an array.

    - The literal string 'TRANSMAT.in' reads rows 2-4 (first 3 columns) of
      that file and returns them directly; nothing is written.
    - Any other string is split on whitespace into floats; other inputs are
      flattened with np.ravel.
    - 1 value gives a scaled identity, 3 values a diagonal matrix, and
      9 values a row-major 3x3 matrix.
    - In these last cases the matrix is also saved to 'TRANSMAT' in the
      current directory.

    Raises
    ------
    ValueError
        If the string is not numeric or the number of values is not 1, 3 or 9.
    '''
    if not isinstance(transform, str):
        flat = np.ravel(transform)
    elif transform == 'TRANSMAT.in':
        return np.loadtxt('TRANSMAT.in', skiprows=1, usecols=(0,1,2), max_rows=3)
    else:
        try:
            flat = np.array([float(v) for v in transform.split()])
        except ValueError:
            eg = "'2 0 0 0 2 0 0 0 2', '2 2 2', or '2'"
            raise ValueError(
                "Invalid transformation matrix: '{}' (allowed: {})".format(transform, eg)
            ) from None

    count = flat.size
    if count == 9:
        matrix = np.reshape(flat, (3, 3))
    elif count == 3:
        matrix = np.diag(flat)
    elif count == 1:
        matrix = np.eye(3) * flat[0]
    else:
        raise ValueError(
            'Invalid transformation matrix size: {} (allowed: 1, 3, 9)'.format(count)
        )
    note = 'Transformation matrix generated by {}'.format(__prog__)
    np.savetxt('TRANSMAT', matrix, fmt='%6g', header=note, comments='')
    return matrix

def fix_charge(incar='INCAR', charge=0, nelect=None):
    '''
    Set the NELECT tag of an INCAR file in place.

    The value written is (nelect - charge). If `nelect` is None, `charge`
    itself is written as NELECT. The first line containing 'nelect'
    (case-insensitive) is replaced entirely; if there is none, the tag is
    appended on a new line.

    NOTE(review): a line holding several ';'-separated tags together with
    NELECT is replaced as a whole, dropping the other tags on that line.

    Parameters
    ----------
    incar : str, optional
        Filename of INCAR, by default 'INCAR'
    charge : int or float, optional
        Charge to subtract from `nelect` (or the NELECT value itself when
        `nelect` is None), by default 0
    nelect : int or float, optional
        Reference number of electrons, by default None
    '''
    chg = charge if nelect is None else (nelect - charge)
    new_line = 'NELECT = {:g}\n'.format(chg)
    with open(incar, 'r') as f:
        lines = f.readlines()

    for idx, line in enumerate(lines):
        if 'nelect' in line.lower():
            lines[idx] = new_line
            break
    else:
        # without this, a file lacking a final newline would get NELECT
        # glued onto its last tag
        if lines and not lines[-1].endswith('\n'):
            lines[-1] += '\n'
        lines.append(new_line)

    with open(incar, 'w') as f:
        f.writelines(lines)
